import os
import sys
import gzip
from collections import Counter, defaultdict
import pysam
from scipy.stats import gaussian_kde
from Bio import SeqIO
from pylab import *


# Name of the target RNA class under analysis (first command-line argument).
# It selects Filters/<target>.fa and Filters/<target>.psl and the output
# figure names.
target = sys.argv[1]


directory = "/osc-fs_home/mdehoon/Data/CASPARs/Filters/"
filename = "%s.fa" % target
path = os.path.join(directory, filename)
# Map each FASTA record ID (transcript name) to the gene name written in the
# last parenthesised part of its description line, e.g. "... (GENE)".
genes = {}
with open(path) as stream:
    for record in SeqIO.parse(stream, "fasta"):
        description = record.description
        i = description.rfind("(") + 1
        j = description.rfind(")")
        # rfind returns -1 when a parenthesis is missing; without this check
        # the slice would silently produce a meaningless "gene" name.
        if i == 0 or j < i:
            raise ValueError(
                "Cannot find gene name in parentheses in description %r"
                % description
            )
        genes[record.id] = description[i:j]

directory = "/osc-fs_home/mdehoon/Data/CASPARs/Filters/"
filename = "%s.psl" % target
path = os.path.join(directory, filename)
# Map each query (transcript) name to its length, taken from the qName
# (10th) and qSize (11th) PSL columns. Every line must have exactly 21
# fields, so the file is expected to carry no psLayout header.
transcript_lengths = {}
with open(path) as stream:
    for line_number, line in enumerate(stream, 1):
        words = line.split()
        # Explicit check rather than assert: asserts vanish under `python -O`.
        if len(words) != 21:
            raise ValueError(
                "%s, line %d: expected 21 PSL fields, found %d"
                % (path, line_number, len(words))
            )
        qName = words[9]
        qSize = int(words[10])
        transcript_lengths[qName] = qSize

# Scan the time-course BAM files. For every mapped record pair assigned to
# `target` (XT tag), tally (annotated transcript length, sequenced length
# from the XL tag). Also collect the genes whose transcripts were seen in
# truncated form.
truncated = set()
directory = "/osc-fs_home/mdehoon/Data/CASPARs/MiSeq/Mapping"
filenames = sorted(os.listdir(directory))
counts = Counter()
for filename in filenames:
    library, extension = os.path.splitext(filename)
    if extension != ".bam":
        raise ValueError(
            "Unexpected file %s in %s (expected only .bam files)"
            % (filename, directory)
        )
    if not library.startswith("t"):
        # include time course data only
        continue
    path = os.path.join(directory, filename)
    print("Reading", path)
    # Close each BAM file once it has been read, instead of keeping one open
    # handle per library until the interpreter exits.
    with pysam.AlignmentFile(path) as alignments:
        # Records are consumed two at a time: the loop variable is the first
        # record and next() pulls the second one of the pair.
        # NOTE(review): assumes mates are adjacent in the file (e.g. aligner
        # output in read-name order) — confirm.
        for alignment1 in alignments:
            alignment2 = next(alignments, None)
            if alignment2 is None:
                raise ValueError(
                    "%s: odd number of records, last record has no mate" % path
                )
            if alignment1.is_unmapped:
                if not alignment2.is_unmapped:
                    raise ValueError(
                        "%s: %s is unmapped but its mate is mapped"
                        % (path, alignment1.query_name)
                    )
                continue
            if alignment2.is_unmapped:
                raise ValueError(
                    "%s: %s is mapped but its mate is unmapped"
                    % (path, alignment1.query_name)
                )
            if alignment1.get_tag("XT") != target:
                continue
            transcripts = alignment1.get_tag("XR").split(",")
            # All transcripts the pair was assigned to must share one length.
            lengths = {transcript_lengths[transcript] for transcript in transcripts}
            if len(lengths) != 1:
                raise ValueError(
                    "%s: transcripts %s of %s have different lengths %s"
                    % (path, transcripts, alignment1.query_name, sorted(lengths))
                )
            transcript_length = lengths.pop()
            sequence_length = alignment1.get_tag("XL")
            counts[(transcript_length, sequence_length)] += 1
            # A sequenced length below 80% of the annotated transcript length
            # marks all genes of the assigned transcripts as truncated.
            if sequence_length < 0.8 * transcript_length:
                for transcript in transcripts:
                    truncated.add(genes[transcript])

print("Number of genes with truncated transcripts: %d" % len(truncated))

# Without any alignments for the target there is nothing to estimate or plot.
# Stop with a clear message instead of an IndexError on the empty array.
if not counts:
    sys.exit("No alignments found for target %s; nothing to plot" % target)

# 3 x N array: row 0 = annotated transcript length, row 1 = sequenced length,
# row 2 = number of alignments observed with that pair of lengths.
data = transpose(array([
    [tSize, length, count] for (tSize, length), count in counts.items()
]))

# Count-weighted kernel density estimate in log-log space, used to colour the
# points. Sorting by density puts the densest points last, so they are drawn
# on top in the scatter plot.
kde = gaussian_kde(log(data[:2,:]), weights=data[2])
z = kde(log(data[:2,:]))
indices = argsort(z)
x, y = data[:2, indices]
z = z[indices]

# Log-log scatter: annotated transcript length (x) against sequenced length
# (y), each point coloured by its estimated density.
scatter(x, y, c=z, edgecolor=None)
xscale('log')
yscale('log')

# Take the autoscaled limits and widen them so that both axes cover at least
# the 50-300 nucleotide range.
lower_x, upper_x = xlim()
lower_y, upper_y = ylim()
xmin, xmax = min(lower_x, 50), max(upper_x, 300)
ymin, ymax = min(lower_y, 50), max(upper_y, 300)
minimum, maximum = min(xmin, ymin), max(xmax, ymax)

# Dashed red horizontal lines at 72 and 272 nt, then the solid red y = x
# diagonal (points below it are shorter than their annotated transcript).
for cutoff in (72, 272):
    plot([xmin, xmax], [cutoff, cutoff], "r--")
plot([minimum, maximum], [minimum, maximum], color='red')

xlabel("%s size [nucleotides]" % target)
ylabel("transcript size [nucleotides]")
title(target, size=12)
xlim(xmin, xmax)
ylim(ymin, ymax)

# Write the current figure twice: once as SVG, then as PNG.
for filename in ("figure_scatter_sizes_%s_timecourse.svg" % target,
                 "figure_scatter_sizes_%s_timecourse.png" % target):
    print("Saving figure to %s" % filename)
    savefig(filename)
